Vision Transformer (ViT)

Vision Transformer (ViT) is a pure Transformer-based architecture for computer vision that treats an image as a sequence of patches, analogous to tokens in NLP. By discarding most of convolution’s built-in inductive biases, ViT demonstrates that a standard Transformer encoder, when pre-trained on sufficient data, can match or surpass [[ResNet|CNNs]] on image classification. It also serves as the architectural foundation for modern generative models like [[DiT]].


1. Core Concept

1.1 “An Image is Worth 16×16 Words”

The key insight of Dosovitskiy et al. (2020) is radical in its simplicity:

Split an image into fixed-size patches, linearly embed each patch, add position embeddings, and feed the resulting sequence of vectors to a standard Transformer encoder.

ViT Architecture Overview
═══════════════════════════════════════════════════════
Input Image (H × W × C, e.g., 224×224×3)

├── Patch Partition: split into N patches of size P×P
│ → N = (H/P) · (W/P) patches, each flattened to P²·C dims

├── Linear Projection (Patch Embedding): P²·C → D
│ → Each patch becomes a D-dimensional token

├── Prepend [CLS] Token (learnable classification embedding)

├── Add Position Embedding (learnable or sinusoidal)


┌─────────────────────────────────────────────────────┐
│ Transformer Encoder × L (e.g., 12 layers) │
│ ┌───────────────────────────────────────────────┐ │
│ │ LayerNorm → Multi-Head Self-Attention → + │ │
│ │ LayerNorm → MLP (GELU) → + │ │
│ └───────────────────────────────────────────────┘ │
│ ... repeat L times ... │
└─────────────────────────────────────────────────────┘

├── Extract [CLS] token from output


MLP Head → Class Prediction
═══════════════════════════════════════════════════════
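The token arithmetic in the overview can be checked in a few lines of plain Python. This is a minimal sketch; the function name `vit_sequence_shape` is hypothetical and not part of any library.

```python
def vit_sequence_shape(height, width, channels, patch_size):
    """Return (num_patches, flattened_patch_dim, sequence_length_with_cls)."""
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be divisible by the patch size")
    num_patches = (height // patch_size) * (width // patch_size)  # N = (H/P) * (W/P)
    patch_dim = patch_size * patch_size * channels                # P^2 * C
    return num_patches, patch_dim, num_patches + 1                # +1 for the [CLS] token


print(vit_sequence_shape(224, 224, 3, 16))  # (196, 768, 197)
```

For the 224×224×3 example with P=16, each flattened patch has 768 values, which happens to equal the ViT-Base embedding dimension.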

1.2 The Inductive Bias Trade-off

| Property | CNN ([[ResNet]]) | ViT |
| --- | --- | --- |
| Locality | Hard-coded via small kernels | Learned from data |
| Translation equivariance | Built-in (weight sharing) | Not built in; learned from data |
| Hierarchy | Explicit (pooling/strides) | Uniform across all layers |
| Global context | Only in deepest layers | Every layer (self-attention) |
| Data efficiency | High (strong priors) | Low (needs massive data) |
| Scalability ceiling | Tends to plateau earlier | Keeps improving with more data |

The trade-off: CNNs start better with limited data; ViT overtakes them when pre-trained on large-scale datasets (ImageNet-21k, JFT-300M).

1.3 Mathematical Formulation

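Given an image $x \in \mathbb{R}^{H \times W \times C}$ and patch size $P$:

Step 1: Patchify & Embed

$$x_p = \mathrm{Patchify}(x) \in \mathbb{R}^{N \times (P^2 C)}, \qquad N = \frac{HW}{P^2}$$

$$z_0 = [x_\text{class};\ x_p^1 E;\ x_p^2 E;\ \dots;\ x_p^N E] + E_\text{pos}$$

where $E \in \mathbb{R}^{(P^2 C) \times D}$ is the patch embedding matrix, $x_\text{class} \in \mathbb{R}^{D}$ is the learnable [CLS] token, and $E_\text{pos} \in \mathbb{R}^{(N+1) \times D}$ is the position embedding.

Step 2: Transformer Encoder (for $\ell = 1, \dots, L$)

$$z'_\ell = \mathrm{MSA}(\mathrm{LN}(z_{\ell-1})) + z_{\ell-1}$$

$$z_\ell = \mathrm{MLP}(\mathrm{LN}(z'_\ell)) + z'_\ell$$

Step 3: Classification

$$\hat{y} = \mathrm{LN}(z_L^0)\, W_\text{head}$$

where $z_L^0$ is the [CLS] token at the final layer.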
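To make the Patchify step concrete, here is a small dependency-free sketch on a toy image stored as nested lists. The helper `patchify` and the raster (row-by-row) patch ordering are illustrative assumptions; real implementations work on tensors.

```python
def patchify(image, patch_size):
    """Split an H x W x C nested-list image into flattened P*P*C patches.

    Patches are emitted in raster order, so there are N = HW / P^2 of them.
    """
    height, width = len(image), len(image[0])
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patch = []
            for row in range(top, top + patch_size):
                for col in range(left, left + patch_size):
                    patch.extend(image[row][col])  # append the C channel values
            patches.append(patch)
    return patches


# Toy 4x4 single-channel image whose pixel value is its raster index.
image = [[[4 * r + c] for c in range(4)] for r in range(4)]
patches = patchify(image, patch_size=2)
print(len(patches), len(patches[0]))  # 4 4
print(patches[0])                     # [0, 1, 4, 5]
```

Each row of the result corresponds to one $x_p^i$; multiplying it by $E$ would give the patch token.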


2. Architecture in Detail

2.1 Patch Embedding

import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Split image into patches and embed them.

    Args:
        img_size: Input image size (square assumed)
        patch_size: Size of each patch
        in_channels: Number of input channels (3 for RGB)
        embed_dim: Output embedding dimension
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        # Use Conv2d for efficient patch extraction + linear projection
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # x: (B, C, H, W) → (B, embed_dim, H/P, W/P)
        x = self.proj(x)
        # Flatten spatial dims → (B, embed_dim, N)
        x = x.flatten(2)
        # Transpose → (B, N, embed_dim)
        x = x.transpose(1, 2)
        return x

Why Conv2d instead of manual unfolding? A single Conv2d(kernel_size=P, stride=P) simultaneously splits patches and projects them — mathematically equivalent but far more efficient on GPU.

2.2 Position Embedding

ViT uses learned 1D position embeddings (not 2D, not sinusoidal by default):

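import torch
import torch.nn as nn

# Standard ViT: learned position embedding (this line lives inside the model's __init__)
self.pos_embed = nn.Parameter(
    torch.randn(1, num_patches + 1, embed_dim) * 0.02
)

# Alternative: 2D sinusoidal (better for variable resolutions)
def get_1d_sincos(embed_dim, pos):
    """Helper (added for this sketch): sin-cos encoding of positions, (M,) → (M, embed_dim)."""
    half = embed_dim // 2
    omega = 1.0 / 10000 ** (torch.arange(half, dtype=torch.float32) / half)
    angles = pos[:, None] * omega[None, :]              # (M, embed_dim/2)
    return torch.cat([angles.sin(), angles.cos()], dim=1)

def get_2d_sincos_pos_embed(embed_dim, grid_size):
    """Generate 2D sinusoidal position embedding.

    Used in MAE and DiT for resolution-agnostic positioning.
    Half of the channels encode the row index and half the column index,
    so embed_dim is assumed to be divisible by 4.
    """
    coords = torch.arange(grid_size, dtype=torch.float32)
    grid_h, grid_w = torch.meshgrid(coords, coords, indexing="ij")
    emb_h = get_1d_sincos(embed_dim // 2, grid_h.reshape(-1))
    emb_w = get_1d_sincos(embed_dim // 2, grid_w.reshape(-1))
    return torch.cat([emb_h, emb_w], dim=1)  # (grid_size², embed_dim)

| Position Embedding Type | Pros | Cons | Used In |
| --- | --- | --- | --- |
| Learned 1D | Simple, effective | Fixed resolution | Original ViT |
| Sinusoidal 1D | Extrapolates to longer sequences | No 2D structure | Vanilla Transformer style |
| Learned 2D | Encodes spatial structure | More parameters | Some ViT variants |
| Sinusoidal 2D | Resolution-agnostic, no params | Less expressive | MAE, DiT |
| Relative (RoPE) | Captures relative distances | Slightly more compute | Recent vision models |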

2.3 Multi-Head Self-Attention (MSA)

class MultiHeadAttention(nn.Module):
    """Standard multi-head self-attention for ViT."""

    def __init__(self, dim, num_heads=12, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # Generate Q, K, V
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

Computational cost: $\mathcal{O}(N^2 D)$, quadratic in the number of patches. For a 224×224 image with P=16: N = 196 tokens → 196² ≈ 38K attention pairs per head (manageable). But for P=4: N = 3136 tokens → 3136² ≈ 9.8M pairs, which is 256× more (16× the tokens, squared).
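The pair counts are easy to verify directly. The sketch below uses a hypothetical helper, `attention_pairs`, and ignores the extra [CLS] token for simplicity.

```python
def attention_pairs(image_size, patch_size):
    """Return (tokens, query-key pairs per head) for a square image, without [CLS]."""
    tokens = (image_size // patch_size) ** 2
    return tokens, tokens * tokens


n16, pairs16 = attention_pairs(224, 16)
n4, pairs4 = attention_pairs(224, 4)
print(n16, pairs16)       # 196 38416
print(n4, pairs4)         # 3136 9834496
print(pairs4 // pairs16)  # 256
```

Shrinking the patch side by 4× multiplies the token count by 16 and the attention cost by 16² = 256.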

2.4 Complete ViT Model

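# Assumes torch, nn, PatchEmbed (2.1) and MultiHeadAttention (2.3) from above.

class Block(nn.Module):
    """Pre-norm Transformer block.

    The model below references Block without defining it; this is a minimal
    sketch of one, not necessarily the exact block used by any library.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(dim, num_heads, attn_drop=attn_drop, proj_drop=drop)
        self.norm2 = nn.LayerNorm(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim), nn.GELU(), nn.Dropout(drop),
            nn.Linear(hidden_dim, dim), nn.Dropout(drop),
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))    # LayerNorm → MSA → residual
        return x + self.mlp(self.norm2(x))  # LayerNorm → MLP → residual


class VisionTransformer(nn.Module):
    """Complete ViT model following the original design."""

    def __init__(
        self, img_size=224, patch_size=16, in_channels=3,
        num_classes=1000, embed_dim=768, depth=12,
        num_heads=12, mlp_ratio=4., drop_rate=0., attn_drop_rate=0.
    ):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches

        # CLS token + position embedding
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim)
        )
        self.pos_drop = nn.Dropout(drop_rate)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, drop_rate, attn_drop_rate)
            for _ in range(depth)
        ])

        # Normalization + classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        self._init_weights()

    def _init_weights(self):
        # Small truncated-normal init for the learned embeddings (std is an assumed value)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)  # (B, N, D)

        # Prepend CLS token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, N+1, D)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Transformer encoder
        for blk in self.blocks:
            x = blk(x)

        # Extract CLS token and classify
        x = self.norm(x)
        return self.head(x[:, 0])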

2.5 ViT Model Variants

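| Model | Layers | Hidden Dim | Heads | MLP Size | Params | Patch |
| --- | --- | --- | --- | --- | --- | --- |
| ViT-Tiny | 12 | 192 | 3 | 768 | 5.7M | 16 |
| ViT-Small | 12 | 384 | 6 | 1536 | 22M | 16 |
| ViT-Base | 12 | 768 | 12 | 3072 | 86M | 16 |
| ViT-Large | 24 | 1024 | 16 | 4096 | 307M | 16 |
| ViT-Huge | 32 | 1280 | 16 | 5120 | 632M | 14 |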
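The parameter counts in the table can be roughly reproduced by adding up the layers of the model in Section 2.4. The sketch below assumes a 224×224 input, a 1000-class head, biases on every linear layer, and a [CLS] token; `vit_param_count` is a hypothetical helper. Differences of a few million from the table may come from the head size or resolution assumed there.

```python
def vit_param_count(depth, dim, mlp_dim, patch_size=16, img_size=224,
                    in_channels=3, num_classes=1000):
    """Count parameters of a plain ViT classifier with a [CLS] token."""
    num_patches = (img_size // patch_size) ** 2
    patch_embed = patch_size * patch_size * in_channels * dim + dim
    cls_and_pos = dim + (num_patches + 1) * dim
    attention = 4 * dim * dim + 4 * dim      # qkv + output projection, with biases
    mlp = 2 * dim * mlp_dim + mlp_dim + dim  # two linear layers, with biases
    layer_norms = 2 * 2 * dim                # two LayerNorms per block
    block = attention + mlp + layer_norms
    final_norm = 2 * dim
    head = dim * num_classes + num_classes
    return patch_embed + cls_and_pos + depth * block + final_norm + head


print(vit_param_count(depth=12, dim=768, mlp_dim=3072))   # 86567656
print(vit_param_count(depth=24, dim=1024, mlp_dim=4096))  # 304326632
```

Almost all parameters sit in the Transformer blocks (about 12·D² per block when the MLP size is 4·D), which is why depth and hidden dimension dominate the model size.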

3. Key Design Principles

3.1 Why Self-Attention for Vision?

| Property | How ViT Achieves It |
| --- | --- |
| Global receptive field | Every patch attends to every other patch, even in layer 1 |
| Dynamic weighting | Attention weights are input-dependent (unlike fixed convolution kernels) |
| Content-based interaction | Similar patches attend to each other regardless of spatial distance |
| Scalable capacity | More data → more meaningful attention patterns learned |
| Multi-modal unification | Same architecture for images, text, audio (foundation models) |

3.2 The [CLS] Token Design

ViT borrows BERT’s [CLS] token convention:

  • A learnable embedding prepended to the patch sequence
  • After L Transformer layers, the [CLS] token’s output serves as the global image representation
  • Alternative: Global Average Pooling (GAP) over all patch tokens — used in some variants

[CLS] Design:
Input:   [CLS]   [Patch₁]   [Patch₂]  ...  [Patch_N]
           ↓        ↓          ↓              ↓
After L Transformer layers:
           ↓        ↓          ↓              ↓
Output:  [CLS]ₗ  [Patch₁]ₗ  [Patch₂]ₗ ...  [Patch_N]ₗ
           ↓
Classification Head → "Golden Retriever"

GAP vs [CLS]: In practice, both work similarly. [CLS] is the original design choice; GAP is simpler and is used in Swin Transformer and some later ViT variants. DeiT keeps the [CLS] token and adds a distillation token next to it.
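The two pooling choices differ only in which encoder outputs feed the classification head. A toy sketch with plain lists follows; the helper names `cls_pool` and `gap_pool` are hypothetical.

```python
def cls_pool(tokens):
    """Use the first token ([CLS]) as the image representation."""
    return tokens[0]


def gap_pool(tokens):
    """Average the patch tokens (everything after [CLS]), one dimension at a time."""
    patch_tokens = tokens[1:]
    dim = len(patch_tokens[0])
    return [sum(tok[d] for tok in patch_tokens) / len(patch_tokens) for d in range(dim)]


# Toy encoder output: one [CLS] token followed by three patch tokens, D = 2.
tokens = [[9.0, 9.0], [1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(cls_pool(tokens))  # [9.0, 9.0]
print(gap_pool(tokens))  # [3.0, 4.0]
```

Either vector can be passed to the same final LayerNorm and linear head.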


4. ViT Variants and Evolution

4.1 Family Tree

ViT (Dosovitskiy et al., 2020)
├── DeiT (Touvron et al., 2020)
│ └── Data-efficient training via distillation
├── Swin Transformer (Liu et al., 2021)
│ └── Hierarchical, shifted windows → linear complexity
├── MAE (He et al., 2022)
│ └── Masked autoencoding → self-supervised pre-training
├── DINO / DINOv2 (Caron et al., 2021/2023)
│ └── Self-distillation → emergent segmentation
├── ViT-Adapter / ViTDet (2022/2023)
│ └── ViT for dense prediction (detection, segmentation)
└── DiT (Peebles & Xie, 2023)
└── ViT as diffusion backbone → SORA, SD3

4.2 DeiT: Data-Efficient ViT

Problem: Original ViT needs 300M+ images (JFT) to beat ResNet.

DeiT solution: Knowledge distillation from a CNN teacher:

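$$\mathcal{L}_\text{DeiT} = \mathcal{L}_\text{CE}(y, \hat{y}) + \lambda\, \mathcal{L}_\text{distill}(\hat{y}_\text{teacher}, \hat{y}_\text{student})$$
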
  • Trained on ImageNet-1K only (1.2M images)
  • Distillation token alongside [CLS] token
  • Achieves ViT-level performance without massive pre-training
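As a rough illustration of the combined objective, the sketch below uses hard-label distillation, in which the teacher's predicted class acts as a second target. This is one possible choice of distillation term; the weight `lam=0.5` and the use of a single set of student logits for both terms are simplifying assumptions (DeiT reads the distillation term from its separate distillation token).

```python
import math


def cross_entropy(logits, target_index):
    """Cross-entropy of one sample, computed with a stable log-sum-exp."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return log_norm - logits[target_index]


def distillation_loss(student_logits, teacher_logits, label, lam=0.5):
    """L = CE(label, student) + lam * CE(argmax(teacher), student)."""
    teacher_label = max(range(len(teacher_logits)), key=teacher_logits.__getitem__)
    supervised = cross_entropy(student_logits, label)
    distill = cross_entropy(student_logits, teacher_label)
    return supervised + lam * distill


student = [2.0, 0.5, 0.1]
teacher = [4.0, 1.0, 0.0]  # teacher also predicts class 0
print(distillation_loss(student, teacher, label=0))
```

When the teacher agrees with the ground-truth label, the loss reduces to (1 + λ) times the ordinary cross-entropy; when it disagrees, the student is pulled toward both targets.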

4.3 Swin Transformer: Hierarchical ViT

Swin introduces shifted window attention to achieve linear complexity:

| Aspect | ViT | Swin Transformer |
| --- | --- | --- |
| Token resolution | Constant (coarse) | Hierarchical (fine → coarse) |
| Attention scope | Global (all tokens) | Local windows (shifted) |
| Complexity | $\mathcal{O}(N^2)$ | $\mathcal{O}(N)$ |
| Dense prediction | Requires adaptation | Native (FPN-compatible) |
| Use case | Classification, generation | Detection, segmentation |
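The complexity row can be illustrated by counting query-key pairs under each scheme. The grid and window sizes below are illustrative, and the helper names are hypothetical; with a fixed window size, the windowed count grows linearly in the number of tokens.

```python
def global_attention_pairs(grid):
    """Pairs per head when every token attends to every token on a grid x grid map."""
    tokens = grid * grid
    return tokens * tokens


def window_attention_pairs(grid, window):
    """Pairs per head when attention is limited to non-overlapping window x window blocks."""
    if grid % window:
        raise ValueError("grid must be divisible by the window size")
    num_windows = (grid // window) ** 2
    tokens_per_window = window * window
    return num_windows * tokens_per_window ** 2


for grid in (56, 112):
    print(grid, global_attention_pairs(grid), window_attention_pairs(grid, 7))
# 56 9834496 153664
# 112 157351936 614656
```

Quadrupling the token count (56×56 to 112×112) multiplies the global cost by 16 but the windowed cost by only 4.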

4.4 MAE: Masked Autoencoder

Inspired by BERT’s masked language modeling:

  • Mask 75% of patches (not 15% like BERT — images are more redundant)
  • Encoder processes only visible patches (efficient)
  • Decoder reconstructs masked patches from encoded visible tokens + mask tokens
  • Pre-training objective: MSE between predicted and original pixels
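
import torch

# MAE masking (simplified: one random mask is shared by the whole batch;
# the reference MAE code draws an independent mask per sample)
def random_masking(x, mask_ratio=0.75):
    """Randomly mask patches for MAE pre-training.

    x: (B, N, D) patch tokens. Returns visible tokens, kept ids, masked ids.
    """
    N = x.shape[1]  # number of patches
    len_keep = int(N * (1 - mask_ratio))

    noise = torch.rand(N, device=x.device)  # one random score per patch
    ids_shuffle = torch.argsort(noise)      # random permutation of patch ids
    ids_keep = ids_shuffle[:len_keep]       # the first len_keep ids stay visible

    return x[:, ids_keep, :], ids_keep, ids_shuffle[len_keep:]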

4.5 DINO: Self-Supervised ViT

DINO (self-DIstillation with NO labels) shows that ViT trained with self-supervision automatically learns semantic segmentation:

  • Teacher-student framework with momentum encoder
  • ViT’s attention maps naturally highlight objects
  • DINOv2 scales this recipe up to a curated dataset of 142M images → general-purpose visual features
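The momentum encoder in the teacher-student setup is an exponential moving average (EMA) of the student weights. A minimal sketch on flat lists of weights follows; `ema_update` is a hypothetical helper, and the default momentum is an assumed value in the range typically used for such encoders.

```python
def ema_update(teacher, student, momentum=0.996):
    """Move each teacher weight a small step toward the matching student weight."""
    return [momentum * t + (1.0 - momentum) * s for t, s in zip(teacher, student)]


teacher = [1.0, 2.0]
student = [0.0, 0.0]
print(ema_update(teacher, student, momentum=0.5))  # [0.5, 1.0]
```

The teacher receives no gradients; it only tracks the student, which keeps its outputs stable as targets.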

5. ViT as Foundation for Generative Models

5.1 ViT → [[DiT]]

[[DiT]] directly inherits ViT’s design for diffusion-based generation:

| Component | ViT | [[DiT]] |
| --- | --- | --- |
| Input | Image | Noisy latent (VAE-compressed) |
| Patch embed | Conv2d projection | Same |
| Position embed | Learned 1D | Sinusoidal 2D (resolution-agnostic) |
| Transformer block | Standard (LN → MSA → LN → MLP) | adaLN-conditioned variant |
| Output | Class logits | Predicted noise ε |
| Conditioning | None ([CLS] token only) | Timestep + class label via adaLN (text in later DiT-style models) |

Why ViT works for diffusion:

  1. Global attention captures long-range structure crucial for image generation
  2. No architectural bottleneck — full-resolution processing throughout
  3. Scalability: the [[DiT]] paper reports that sample quality (FID) improves steadily as model compute increases
  4. Unified ecosystem — same infrastructure as LLMs (FlashAttention, FSDP, etc.)

5.2 Patch Size in Generative Context

In DiT, the patch size p controls the compute-quality trade-off:

For a 32×32 latent and embedding dimension D:

  • p=1 : 1024 tokens → heavy but maximum detail
  • p=2 : 256 tokens → default (best trade-off, DiT-XL/2)
  • p=4 : 64 tokens → fast, some quality loss
  • p=8 : 16 tokens → very fast, coarse results
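The token counts above follow from the same patch arithmetic as in ViT, applied to the latent grid. A small sketch (the helper name `dit_num_tokens` is hypothetical) that also prints the resulting attention pair counts:

```python
def dit_num_tokens(latent_size, patch_size):
    """Number of tokens for a square latent split into patch_size x patch_size patches."""
    if latent_size % patch_size:
        raise ValueError("latent size must be divisible by the patch size")
    return (latent_size // patch_size) ** 2


for p in (1, 2, 4, 8):
    n = dit_num_tokens(32, p)
    print(p, n, n * n)  # patch size, tokens, attention pairs per head
# 1 1024 1048576
# 2 256 65536
# 4 64 4096
# 8 16 256
```

Each doubling of p cuts the token count by 4× and the attention pairs by 16×, which is where the compute savings come from.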

5.3 Beyond DiT: SORA and Video Generation

SORA extends the ViT→DiT lineage to video by using 3D spacetime patches:

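$$\text{Video} \in \mathbb{R}^{T \times H \times W \times C} \xrightarrow{\ \mathrm{Patchify}(p_t,\, p_h,\, p_w)\ } \text{Tokens} \in \mathbb{R}^{N \times D}$$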

This is a direct generalization: ViT treats an image as 2D patches; SORA treats a video as 3D spacetime patches.
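The token count for spacetime patches generalizes the 2D formula by one extra factor for time. The sketch below uses made-up sizes purely for illustration; they are not SORA's actual configuration, and `spacetime_tokens` is a hypothetical helper.

```python
def spacetime_tokens(frames, height, width, pt, ph, pw):
    """N = (T / p_t) * (H / p_h) * (W / p_w) for non-overlapping 3D patches."""
    if frames % pt or height % ph or width % pw:
        raise ValueError("each dimension must be divisible by its patch size")
    return (frames // pt) * (height // ph) * (width // pw)


# Illustrative: 16 latent frames of 32x32, patches spanning 2 frames x 2 x 2.
print(spacetime_tokens(16, 32, 32, pt=2, ph=2, pw=2))  # 2048
# With a single frame and p_t = 1 this reduces to the 2D image case.
print(spacetime_tokens(1, 32, 32, pt=1, ph=2, pw=2))   # 256
```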


6. Training and Transfer Learning

6.1 Pre-training Strategies

| Strategy | Data Required | Performance | Example |
| --- | --- | --- | --- |
| Supervised (JFT-300M) | 300M labeled images | Best | Original ViT |
| Supervised (ImageNet-21K) | 14M labeled images | Strong | ViT-B/16 @ 384 |
| Supervised + Distillation (ImageNet-1K) | 1.2M labeled images | Good | DeiT |
| Self-supervised (MAE) | Unlabeled images | Strong | MAE ViT-L |
| Self-supervised (DINOv2) | 142M curated images | SOTA features | DINOv2 ViT-g |

6.2 Fine-tuning at Higher Resolution

ViT can be fine-tuned on resolutions different from pre-training:

import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed, new_num_patches):
    """Interpolate position embedding for higher-resolution input.

    When fine-tuning ViT at 384×384 (trained at 224×224):
    - Original: 14×14 = 196 patch tokens
    - New: 24×24 = 576 patch tokens
    - Interpolate pos_embed from (1, 197, D) to (1, 577, D)
    """
    cls_embed = pos_embed[:, :1, :]    # CLS token (unchanged)
    patch_embed = pos_embed[:, 1:, :]  # Patch tokens

    # Both grids are assumed to be square
    old_size = int(patch_embed.shape[1] ** 0.5)
    new_size = int(new_num_patches ** 0.5)

    # Reshape to 2D and interpolate
    patch_embed = patch_embed.reshape(1, old_size, old_size, -1)
    patch_embed = patch_embed.permute(0, 3, 1, 2)  # (1, D, H, W)
    patch_embed = F.interpolate(
        patch_embed, size=(new_size, new_size), mode='bicubic', align_corners=False
    )
    patch_embed = patch_embed.permute(0, 2, 3, 1).reshape(1, -1, pos_embed.shape[-1])

    return torch.cat([cls_embed, patch_embed], dim=1)

6.3 Key Hyperparameters

| Parameter | ViT-B/16 | ViT-L/16 | Notes |
| --- | --- | --- | --- |
| Optimizer | AdamW | AdamW | β₁=0.9, β₂=0.999 |
| Learning rate | 3e-3 | 3e-3 | Cosine schedule |
| Weight decay | 0.3 | 0.3 | Excludes bias & LN |
| Batch size | 4096 | 4096 | Large batches crucial |
| Epochs | 300 | 300 | + warmup (10K steps) |
| Dropout | 0.1 | 0.1 | Reduced for larger data |
| Gradient clipping | 1.0 | 1.0 | Global norm |
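The learning-rate row combines a warmup phase with a cosine schedule. One common way to implement that combination is sketched below; the linear warmup shape, the decay to zero, and the helper name `lr_at_step` are assumptions rather than details taken from the table.

```python
import math


def lr_at_step(step, total_steps, base_lr=3e-3, warmup_steps=10_000, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


total = 100_000  # illustrative number of training steps
for step in (0, 9_999, 55_000, 100_000):
    print(step, lr_at_step(step, total))
```

The rate rises to its peak at the end of warmup, passes through half the peak midway through the decay, and reaches `min_lr` at the final step.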

7. Comparison: ViT vs CNN vs Hybrid

7.1 Performance Landscape

| Architecture | ImageNet Top-1 | Params | FLOPs | Pre-train Data |
| --- | --- | --- | --- | --- |
| ResNet-50 | 76.1% | 25M | 4.1G | ImageNet-1K |
| ResNet-152 | 78.3% | 60M | 11.6G | ImageNet-1K |
| EfficientNet-B7 | 84.3% | 66M | 38G | ImageNet-1K |
| ViT-B/16 | 77.9% / 84.0% | 86M | 17.6G | IN-1K / IN-21K |
| ViT-L/16 | 76.5% / 85.2% | 307M | 61.6G | IN-1K / IN-21K |
| DeiT-B | 81.8% | 86M | 17.6G | ImageNet-1K |

ViT underperforms on ImageNet-1K alone but dominates when pre-trained on larger datasets — reflecting the inductive bias trade-off.

7.2 When to Use What

| Scenario | Recommendation | Reason |
| --- | --- | --- |
| Small dataset (<100K) | CNN or DeiT | ViT overfits without massive data |
| Medium dataset (100K–1M) | ViT + strong augmentation | Or pre-trained ViT + fine-tuning |
| Large dataset (>1M) | ViT | Full potential unlocked |
| Dense prediction | Swin / ConvNeXt | Hierarchical features needed |
| Generative modeling | DiT (ViT-based) | Scalability + global attention |
| Multi-modal | ViT | Unified Transformer ecosystem |
| Edge deployment | MobileNet / EfficientNet | ViT too heavy without optimization |

8. Scaling Properties

8.1 ViT Scaling Behavior

ViT exhibits three-phase scaling:

Performance vs Data Size (ViT):
Acc ↑
│ ╱────────────── ViT-L (Saturates late)
│ ╱
│ ╱
│ ╱──╱ ViT-B
│ ╱
│ ╱
│╱ ResNet-152 (Saturates early)

└──────────────────────→ Data Size
1M 10M 100M 1B
  • Phase 1 (small data): [[ResNet]] wins — strong inductive biases compensate for data scarcity
  • Phase 2 (medium data): ViT catches up — self-attention starts learning useful patterns
  • Phase 3 (large data): ViT dominates — global attention unlocks superior representations

8.2 Compute-Optimal ViT

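By analogy with the Chinchilla scaling analysis for language models, a compute-optimal recipe for a given compute budget $C$ would scale model size and data together:

$$N_\text{optimal} \propto C^{0.5}, \qquad D_\text{optimal} \propto C^{0.5}$$

where $N$ is the number of model parameters and $D$ is the number of training tokens (image-equivalent). In this subsection these symbols do not mean the patch count and embedding dimension used earlier. The exponents are borrowed from the language-model result, so for ViT they are best read as a rule of thumb rather than a measured fit.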
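Under the square-root rule, any increase in compute is split evenly between model size and data. A small sketch, with a hypothetical helper name and illustrative starting values:

```python
def scale_budget(params, tokens, compute_multiplier, exponent=0.5):
    """Scale parameters and training tokens by compute_multiplier ** exponent each."""
    factor = compute_multiplier ** exponent
    return params * factor, tokens * factor


# Illustrative starting point: 100M parameters, 1B training tokens, 4x more compute.
print(scale_budget(100e6, 1e9, compute_multiplier=4.0))  # (200000000.0, 2000000000.0)
```

With 4× the compute, both the model and the dataset double; their product, which tracks total compute, grows by 4×.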


9. Theoretical Understanding

9.1 Attention as a Learnable Filter

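A CNN applies the same learned local filters to every input, whereas ViT’s self-attention computes input-dependent weights:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

The attention matrix $A = \mathrm{softmax}(QK^\top / \sqrt{d_k})$ has entries $A_{ij}$ that represent how much patch $i$ attends to patch $j$. These weights are computed from the data rather than hard-coded by the architecture.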
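The formula can be traced by hand on tiny matrices. The sketch below is a dependency-free illustration for a single head with no masking; the function names are hypothetical.

```python
import math


def softmax(row):
    """Numerically stable softmax over one list of scores."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for lists of row vectors; returns (output, A)."""
    d_k = len(K[0])
    scores = [[sum(q * k for q, k in zip(q_row, k_row)) / math.sqrt(d_k)
               for k_row in K] for q_row in Q]
    A = [softmax(row) for row in scores]
    out = [[sum(a * v_row[d] for a, v_row in zip(a_row, V)) for d in range(len(V[0]))]
           for a_row in A]
    return out, A


# A zero query scores every key equally, so the output is the mean of the values.
out, A = attention(Q=[[0.0, 0.0]], K=[[1.0, 0.0], [0.0, 1.0]], V=[[2.0, 0.0], [0.0, 4.0]])
print(A)    # [[0.5, 0.5]]
print(out)  # [[1.0, 2.0]]
```

Changing the query changes the row of A, which is the sense in which the filter is input-dependent.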

9.2 Why ViT Needs More Data

Theoretical explanation: ViT has weaker inductive biases than CNN.

  • CNN’s convolution is equivariant to translation by design: $f(T(x)) = T(f(x))$
  • ViT must learn translation-equivariant behavior from data alone
  • This requires seeing objects at diverse positions — hence more data

The sample complexity of learning translation invariance from scratch is significantly higher than having it built into the architecture.
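The equivariance property can be checked numerically on a 1D toy signal. In the sketch below, a circular convolution commutes with shifts, while a layer that depends on absolute position does not; the latter is only a crude stand-in for the effect of position embeddings, not a model of ViT. All helper names are hypothetical.

```python
def circular_conv1d(signal, kernel):
    """Convolution-style layer: the same kernel is applied at every position."""
    n = len(signal)
    return [sum(w * signal[(i + k) % n] for k, w in enumerate(kernel)) for i in range(n)]


def add_position_bias(signal):
    """Toy position-dependent layer: the output depends on the absolute index."""
    return [value + index for index, value in enumerate(signal)]


def shift(signal, steps):
    """Circularly shift a list to the right by `steps` positions."""
    steps %= len(signal)
    return signal[-steps:] + signal[:-steps]


x = [1, 2, 3, 4, 5]
kernel = [1, -1]
print(circular_conv1d(shift(x, 2), kernel) == shift(circular_conv1d(x, kernel), 2))  # True
print(add_position_bias(shift(x, 2)) == shift(add_position_bias(x), 2))              # False
```

The convolution satisfies f(T(x)) = T(f(x)) for free; the position-dependent layer would have to learn a compensating behavior from examples of shifted inputs.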

9.3 Fourier Analysis Perspective

Park & Kim (2022) analyze ViTs in the Fourier domain and report that multi-head self-attention tends to act as a low-pass filter, whereas convolution behaves more like a high-pass filter. On this view the two operations are complementary: ViT builds up features through learned attention patterns, while a CNN does so through architectural constraints.


10. Core Formula Cards

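| # | Formula | Meaning |
| --- | --- | --- |
| 1 | $z_0 = [x_\text{class};\ x_p E] + E_\text{pos}$ | Patch embedding + position encoding |
| 2 | $z'_\ell = \mathrm{MSA}(\mathrm{LN}(z_{\ell-1})) + z_{\ell-1}$ | Residual self-attention |
| 3 | $z_\ell = \mathrm{MLP}(\mathrm{LN}(z'_\ell)) + z'_\ell$ | Residual MLP |
| 4 | $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$ | Scaled dot-product attention |
| 5 | $\mathrm{MSA}(X) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^O$ | Multi-head aggregation |
| 6 | $\hat{y} = \mathrm{LN}(z_L^0)\, W_\text{head}$ | Classification from [CLS] token |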

11. Summary

ViT’s core contribution is a philosophical shift in computer vision:

| Era | Paradigm | Philosophy |
| --- | --- | --- |
| Pre-2020 | CNNs | “Hard-code what we know about vision (locality, translation invariance)” |
| Post-2020 | ViT/Transformers | “Provide capacity + data; let the model learn visual structure” |

This shift enabled:

  • Unified architectures across vision, language, and audio
  • Scalable pre-training (MAE, DINO, CLIP) rivaling NLP’s BERT/GPT moment
  • Generative models at scale — [[DiT]], SORA, Stable Diffusion 3 all inherit ViT’s design

ViT is not just an architecture — it’s the bridge that brought vision into the Transformer era, enabling the same scaling laws and infrastructure that revolutionized NLP to transform computer vision and generative modeling.


Dataview Query

LIST
FROM #vit OR #vision_transformer OR #self_attention
WHERE type = "architecture"
SORT file.ctime DESC

  • [[ResNet]]
  • [[DiT]]
  • [[Diffusion Model]]
  • [[U-Net]]
  • [[Transformer]]
  • [[RoPE]]
  • [[Neural ODE]]
  • [[Convolutional Neural Network (CNN)]]

References

  • Paper: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale (Dosovitskiy et al., ICLR 2021 — Oral)
  • Paper: Training data-efficient image transformers & distillation through attention (Touvron et al., ICML 2021 — DeiT)
  • Paper: Swin Transformer: Hierarchical Vision Transformer using Shifted Windows (Liu et al., ICCV 2021 — Best Paper)
  • Paper: Masked Autoencoders Are Scalable Vision Learners (He et al., CVPR 2022)
  • Paper: Emerging Properties in Self-Supervised Vision Transformers (Caron et al., ICCV 2021 — DINO)
  • Paper: Scalable Diffusion Models with Transformers (Peebles & Xie, ICCV 2023 — [[DiT]])
  • Paper: Attention Is All You Need (Vaswani et al., NeurIPS 2017 — Original Transformer)
  • Blog: Vision Transformer (ViT) — A Flutter on ImageNet — lucidrains
  • Blog: The Illustrated Vision Transformer — Jay Alammar
  • Code: https://github.com/huggingface/pytorch-image-models (timm)
  • Code: https://github.com/facebookresearch/dino (DINO)
  • Code: https://github.com/facebookresearch/mae (MAE)
  • Course: CS231n Convolutional Neural Networks for Visual Recognition (Stanford)
  • Course: CS25 Transformers United (Stanford)